{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Sentencepiece python module example",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/google/sentencepiece/blob/master/python/sentencepiece_python_module_example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "metadata": {
        "id": "T9BDzLVkUFT4",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Sentencepiece python module\n",
        "\n",
        "\n",
        "This notebook describes comprehensive examples of sentencepiece Python module. \n",
        "Since Python module calls C++ API through SWIG,  this document is also useful for developing C++ client."
      ]
    },
    {
      "metadata": {
        "id": "kIgXb6P2Yg6g",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Install and data preparation\n",
        "\n",
        "We use the small training data (botchan.txt) in this example. \n",
        "([Botchan](https://en.wikipedia.org/wiki/Botchan) is a novel written by Natsume Sōseki in 1906.  The sample is English-translated one.)"
      ]
    },
    {
      "metadata": {
        "id": "SUcAbKnRVAv6",
        "colab_type": "code",
        "outputId": "d9710f00-25e7-4fe2-d3fb-a24ab974a18f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 240
        }
      },
      "cell_type": "code",
      "source": [
        "!pip install sentencepiece\n",
        "!wget https://raw.githubusercontent.com/google/sentencepiece/master/data/botchan.txt"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (0.1.81)\n",
            "--2019-03-27 21:17:13--  https://raw.githubusercontent.com/google/sentencepiece/master/data/botchan.txt\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 278779 (272K) [text/plain]\n",
            "Saving to: ‘botchan.txt.1’\n",
            "\n",
            "botchan.txt.1       100%[===================>] 272.25K  --.-KB/s    in 0.05s   \n",
            "\n",
            "2019-03-27 21:17:13 (5.50 MB/s) - ‘botchan.txt.1’ saved [278779/278779]\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "-k5KbVgiYae-",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Basic  end-to-end example\n",
        "\n"
      ]
    },
    {
      "metadata": {
        "id": "ee9W6wGnVteW",
        "colab_type": "code",
        "outputId": "c8cbe6d9-d052-4e6f-b5ab-270445a84f93",
        "cellView": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 91
        }
      },
      "cell_type": "code",
      "source": [
        "import sentencepiece as spm\n",
        "\n",
        "# train sentencepiece model from `botchan.txt` and makes `m.model` and `m.vocab`\n",
        "# `m.vocab` is just a reference. not used in the segmentation.\n",
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000')\n",
        "\n",
        "# makes segmenter instance and loads the model file (m.model)\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "\n",
        "# encode: text => id\n",
        "print(sp.encode_as_pieces('This is a test'))\n",
        "print(sp.encode_as_ids('This is a test'))\n",
        "\n",
        "# decode: id => text\n",
        "print(sp.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']))\n",
        "print(sp.decode_ids([209, 31, 9, 375, 586]))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁This', '▁is', '▁a', '▁t', 'est']\n",
            "[209, 31, 9, 375, 586]\n",
            "This is a test\n",
            "This is a test\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "4vHnQbBOltZo",
        "colab_type": "code",
        "outputId": "9bb1ecaf-2883-494c-e34b-5616efac126a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 147
        }
      },
      "cell_type": "code",
      "source": [
        "# returns vocab size\n",
        "print(sp.get_piece_size())\n",
        "\n",
        "# id <=> piece conversion\n",
        "print(sp.id_to_piece(209))\n",
        "print(sp.piece_to_id('▁This'))\n",
        "\n",
        "# returns 0 for unknown tokens (we can change the id for UNK)\n",
        "print(sp.piece_to_id('__MUST_BE_UNKNOWN__'))\n",
        "\n",
        "# <unk>, <s>, </s> are defined by default. Their ids are (0, 1, 2)\n",
        "# <s> and </s> are defined as 'control' symbol.\n",
        "for id in range(3):\n",
        "  print(sp.id_to_piece(id), sp.is_control(id))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "2000\n",
            "▁This\n",
            "209\n",
            "0\n",
            "<unk> False\n",
            "<s> True\n",
            "</s> True\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "MRv6EeC2Y2PE",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Loads model from byte stream\n",
        "\n",
        "Sentencepiece's model file is just a serialized [protocol buffer](https://developers.google.com/protocol-buffers/). We can instantiate sentencepiece processor from byte object with **load_from_serialized_proto** method."
      ]
    },
    {
      "metadata": {
        "id": "0Bdi9SuxYAud",
        "colab_type": "code",
        "outputId": "b1566541-288e-4aa3-9c75-e494d0ab276a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        }
      },
      "cell_type": "code",
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Assumes that m.model is stored in non-Posix file system.\n",
        "serialized_model_proto = tf.gfile.GFile('m.model', 'rb').read()\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load_from_serialized_proto(serialized_model_proto)\n",
        "\n",
        "print(sp.encode_as_pieces('this is a test'))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁this', '▁is', '▁a', '▁t', 'est']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "imfPyYlVZmxz",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## User defined and control symbols\n",
        "\n",
        "We can define special tokens (symbols) to tweak the DNN behavior through the tokens.   Typical examples are  [BERT](https://arxiv.org/abs/1810.04805)'s special symbols., e.g., [SEP] and [CLS].\n",
        "\n",
        "There are two types of special tokens:\n",
        "\n",
        "- **user defined symbols**: Always treated as one token in any context. These symbols can appear in the input sentence. \n",
        "- **control symbol**:  We only reserve ids for these tokens. Even if these tokens appear in the input text, they are not handled as one token. User needs to insert ids explicitly after encoding.\n",
        "\n",
        "For experimental purpose, user defined symbols are easier to use since user can change the behavior just by modifying the input text. However,  we want to use control symbols in the production setting in order to avoid users from tweaking the behavior by feeding these special symbols in their input text."
      ]
    },
    {
      "metadata": {
        "id": "dngckiPMcWbA",
        "colab_type": "code",
        "outputId": "e52883b1-6452-4a90-8802-945c9ac9d5b5",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 110
        }
      },
      "cell_type": "code",
      "source": [
        "## Example of user defined symbols\n",
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_user --user_defined_symbols=<sep>,<cls> --vocab_size=2000')\n",
        "\n",
        "sp_user = spm.SentencePieceProcessor()\n",
        "sp_user.load('m_user.model')\n",
        "\n",
        "# ids are reserved in both mode.\n",
        "# <unk>=0, <s>=1, </s>=2, <sep>=3, <cls>=4\n",
        "# user defined symbols allow these symbol to apper in the text.\n",
        "print(sp_user.encode_as_pieces('this is a test<sep> hello world<cls>'))\n",
        "print(sp_user.piece_to_id('<sep>'))  # 3\n",
        "print(sp_user.piece_to_id('<cls>'))  # 4\n",
        "print('3=', sp_user.decode_ids([3]))  # decoded to <sep>\n",
        "print('4=', sp_user.decode_ids([4]))  # decoded to <cls>"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁this', '▁is', '▁a', '▁t', 'est', '<sep>', '▁he', 'll', 'o', '▁world', '<cls>']\n",
            "3\n",
            "4\n",
            "3= <sep>\n",
            "4= <cls>\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "5awRJ0y1oYm-",
        "colab_type": "code",
        "outputId": "a5fa1ef9-ee5f-4f7d-b6bd-b611979b7350",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 110
        }
      },
      "cell_type": "code",
      "source": [
        "## Example of control symbols\n",
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_ctrl --control_symbols=<sep>,<cls> --vocab_size=2000')\n",
        "\n",
        "sp_ctrl = spm.SentencePieceProcessor()\n",
        "sp_ctrl.load('m_ctrl.model')\n",
        "\n",
        "# control symbols just reserve ids.\n",
        "print(sp_ctrl.encode_as_pieces('this is a test<sep> hello world<cls>'))\n",
        "print(sp_ctrl.piece_to_id('<sep>'))  # 3\n",
        "print(sp_ctrl.piece_to_id('<cls>'))  # 4\n",
        "print('3=', sp_ctrl.decode_ids([3]))  # decoded to empty\n",
        "print('4=', sp_ctrl.decode_ids([4]))  # decoded to empty"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁this', '▁is', '▁a', '▁t', 'est', '<', 'se', 'p', '>', '▁he', 'll', 'o', '▁world', '<', 'c', 'l', 's', '>']\n",
            "3\n",
            "4\n",
            "3= \n",
            "4= \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "8ppZck91s0rq",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        " BOS/EOS (&lt;s&gt;, &lt;/s&gt;) are defined as control symbols, but we can define them as user defined symbols."
      ]
    },
    {
      "metadata": {
        "id": "PQoZ8paVhcEL",
        "colab_type": "code",
        "outputId": "c17d2ae2-17cd-4875-997c-fe66dba23dde",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_bos_as_user --user_defined_symbols=<s>,</s> --vocab_size=2000')\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "print(sp.encode_as_pieces('<s> hello</s>'))   # <s>,</s> are segmented. (default behavior)\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m_bos_as_user.model')\n",
        "print(sp.encode_as_pieces('<s> hello</s>'))   # <s>,</s> are handled as one token."
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁', '<', 's', '>', '▁he', 'll', 'o', '</', 's', '>']\n",
            "['▁', '<s>', '▁he', 'll', 'o', '</s>']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "RZ2GjO5Tmjk9",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Manipulating BOS/EOS/EOS/PAD symbols\n",
        "\n",
        "BOS, EOS, UNK, and PAD ids can be obtained with **bos_id()**, **eos_id()**,  **unk_id()**, and **pad_id()** methods. We can explicitly insert these ids as follows."
      ]
    },
    {
      "metadata": {
        "id": "UtFQqK3tmp7G",
        "colab_type": "code",
        "outputId": "03e82bec-be40-4574-ab62-c66cdc3f28a6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 128
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000')\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "\n",
        "print('bos=', sp.bos_id())\n",
        "print('eos=', sp.eos_id())\n",
        "print('unk=', sp.unk_id())\n",
        "print('pad=', sp.pad_id())  # disabled by default\n",
        "\n",
        "\n",
        "print(sp.encode_as_ids('Hello world'))\n",
        "\n",
        "# Prepend or append bos/eos ids.\n",
        "print([sp.bos_id()] + sp.encode_as_ids('Hello world') + [sp.eos_id()])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "bos= 1\n",
            "eos= 2\n",
            "unk= 0\n",
            "pad= -1\n",
            "[12, 1828, 1038]\n",
            "[1, 12, 1828, 1038, 2]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "2CLaMlHUh4Dk",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Changing the vocab id and surface representation of UNK/BOS/EOS/PAD symbols\n",
        "\n",
        "By default, UNK/BOS/EOS/PAD tokens and their ids are defined as follows:\n",
        "\n",
        "|token|UNK|BOS|EOS|PAD|\n",
        "---|---\n",
        "|surface|&lt;unk&gt;|&lt;s&gt;|&lt;/s&gt;|&lt;pad&gt;|\n",
        "|id|0|1|2|undefined (-1)|\n",
        "\n",
        "\n",
        "We can change these mappings with **--{unk|bos|eos|pad}_id** and **--{unk|bos|eos|pad}_piece** flags."
      ]
    },
    {
      "metadata": {
        "id": "PKn1f3eih_We",
        "colab_type": "code",
        "outputId": "a9349dd9-0cd7-49d7-8ed3-5b35658b3116",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 91
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --vocab_size=2000 --model_prefix=m --pad_id=0 --unk_id=1 --bos_id=2 --eos_id=3 --pad_piece=[PAD] --unk_piece=[UNK] --bos_piece=[BOS] --eos_piece=[EOS]')\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "\n",
        "\n",
        "for id in range(4):\n",
        "    print(sp.id_to_piece(id), sp.is_control(id))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[PAD] True\n",
            "[UNK] False\n",
            "[BOS] True\n",
            "[EOS] True\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "jPVnkpQstMOw",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "When -1 is set,  this special symbol is disabled. UNK must not be undefined."
      ]
    },
    {
      "metadata": {
        "id": "59jHBemKlU8b",
        "colab_type": "code",
        "outputId": "893ce083-7a31-468c-a49c-33d4ba962014",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 73
        }
      },
      "cell_type": "code",
      "source": [
        "# Disable BOS/EOS\n",
        "spm.SentencePieceTrainer.train('--input=botchan.txt --vocab_size=2000 --model_prefix=m --bos_id=-1 --eos_id=-1')\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "\n",
        "# <s>, </s> are UNK.\n",
        "print(sp.unk_id())\n",
        "print(sp.piece_to_id('<s>'))\n",
        "print(sp.piece_to_id('</s>'))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0\n",
            "0\n",
            "0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "frGQ-lhkU03z",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "UNK id is decoded into U+2047\t(⁇) by default.  We can change UNK surface with **--unk_surface=&lt;STR&gt;** flag."
      ]
    },
    {
      "metadata": {
        "id": "S34JDUUAVe41",
        "colab_type": "code",
        "outputId": "25b0349b-4ddd-40f3-a7fb-1f4810b645c2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --vocab_size=2000 --model_prefix=m')\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "print(sp.decode_ids([sp.unk_id()]))   # default is U+2047\n",
        "\n",
        "spm.SentencePieceTrainer.train('--input=botchan.txt --vocab_size=2000 --model_prefix=m --unk_surface=__UNKNOWN__')\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "print(sp.decode_ids([sp.unk_id()])) "
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            " ⁇ \n",
            "__UNKNOWN__\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "5vDXA3Q6kjCS",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Sampling and nbest segmentation for subword regularization\n",
        "\n",
        "When **--model_type=unigram** (default) is used,  we can perform sampling and n-best segmentation for data augmentation. See subword regularization paper [[kudo18]](https://www.google.com/search?q=subword+regularization&rlz=1CAASUL_enJP841&oq=subword+regu&aqs=chrome.0.69i59j69i61j69i57j69i61l2j0.1571j0j7&sourceid=chrome&ie=UTF-8) for more detail."
      ]
    },
    {
      "metadata": {
        "id": "nSQp93qflZO3",
        "colab_type": "code",
        "outputId": "d5b45b62-2789-4879-b80b-f441e8b594af",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 388
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000')\n",
        "\n",
        "# Can obtain different segmentations per request.\n",
        "# There are two hyperparamenters for sampling (nbest_size and inverse temperature). see the paper [kudo18] for detail.\n",
        "for n in range(10):\n",
        "  print(sp.sample_encode_as_pieces('hello world', -1, 0.1))\n",
        "  \n",
        "for n in range(10):\n",
        "  print(sp.sample_encode_as_ids('hello world', -1, 0.1))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁', 'h', 'e', 'll', 'o', '▁w', 'o', 'r', 'l', 'd']\n",
            "['▁he', 'l', 'l', 'o', '▁world']\n",
            "['▁he', 'l', 'l', 'o', '▁w', 'or', 'l', 'd']\n",
            "['▁', 'he', 'l', 'l', 'o', '▁world']\n",
            "['▁', 'he', 'll', 'o', '▁w', 'o', 'r', 'l', 'd']\n",
            "['▁', 'he', 'll', 'o', '▁world']\n",
            "['▁he', 'll', 'o', '▁world']\n",
            "['▁', 'he', 'll', 'o', '▁world']\n",
            "['▁he', 'll', 'o', '▁w', 'o', 'r', 'l', 'd']\n",
            "['▁', 'h', 'e', 'l', 'l', 'o', '▁w', 'o', 'r', 'l', 'd']\n",
            "[12, 489, 57, 57, 38, 1246, 57, 20]\n",
            "[28, 98, 38, 1038]\n",
            "[12, 489, 98, 38, 12, 151, 105, 57, 20]\n",
            "[12, 489, 98, 38, 1038]\n",
            "[28, 98, 38, 254, 105, 57, 20]\n",
            "[12, 489, 98, 38, 12, 151, 38, 46, 57, 20]\n",
            "[28, 57, 57, 38, 1038]\n",
            "[28, 98, 38, 1038]\n",
            "[12, 96, 351, 57, 38, 1038]\n",
            "[28, 98, 38, 1038]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "9V1snUZdlb_v",
        "colab_type": "code",
        "outputId": "bae2ca99-aca3-4013-c240-53b09a0bb684",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 74
        }
      },
      "cell_type": "code",
      "source": [
        "# get 10 best\n",
        "print(sp.nbest_encode_as_pieces('hello world', 10))\n",
        "print(sp.nbest_encode_as_ids('hello world', 10))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[['▁he', 'll', 'o', '▁world'], ['▁he', 'l', 'l', 'o', '▁world'], ['▁', 'he', 'll', 'o', '▁world'], ['▁', 'h', 'e', 'll', 'o', '▁world'], ['▁he', 'll', 'o', '▁wor', 'l', 'd'], ['▁', 'he', 'l', 'l', 'o', '▁world'], ['▁', 'h', 'el', 'l', 'o', '▁world'], ['▁he', 'll', 'o', '▁w', 'or', 'l', 'd'], ['▁', 'h', 'e', 'l', 'l', 'o', '▁world'], ['▁he', 'l', 'l', 'o', '▁wor', 'l', 'd']]\n",
            "[[28, 98, 38, 1038], [28, 57, 57, 38, 1038], [12, 489, 98, 38, 1038], [12, 96, 25, 98, 38, 1038], [28, 98, 38, 1246, 57, 20], [12, 489, 57, 57, 38, 1038], [12, 96, 351, 57, 38, 1038], [28, 98, 38, 254, 105, 57, 20], [12, 96, 25, 57, 57, 38, 1038], [28, 57, 57, 38, 1246, 57, 20]]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "cH6cxuVNcDKh",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## BPE (Byte pair encoding) model\n",
        "\n",
        "Sentencepiece supports BPE (byte-pair-encoding) for subword segmentation with **--model_type=bpe** flag.   We do not find empirical differences in translation quality between BPE and unigram model, but unigram model can perform sampling and n-best segmentation. See subword regularization paper [[kudo18]](https://www.google.com/search?q=subword+regularization&rlz=1CAASUL_enJP841&oq=subword+regu&aqs=chrome.0.69i59j69i61j69i57j69i61l2j0.1571j0j7&sourceid=chrome&ie=UTF-8) for more detail."
      ]
    },
    {
      "metadata": {
        "id": "MNQxuX4Mc0KY",
        "colab_type": "code",
        "outputId": "a6ed3e99-46c3-4c5e-dc8e-80394e17e363",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 73
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_bpe --vocab_size=2000 --model_type=bpe')\n",
        "sp_bpe = spm.SentencePieceProcessor()\n",
        "sp_bpe.load('m_bpe.model')\n",
        "\n",
        "print('*** BPE ***')\n",
        "print(sp_bpe.encode_as_pieces('thisisatesthelloworld'))\n",
        "print(sp_bpe.nbest_encode_as_pieces('hello world', 5))  # returns an empty list."
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "*** BPE ***\n",
            "['▁this', 'is', 'at', 'est', 'he', 'llow', 'or', 'ld']\n",
            "[]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "EZrj1zCkvK8v",
        "colab_type": "code",
        "outputId": "8d5421d9-9b89-440d-c91a-f03c081947d9",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 93
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_unigram --vocab_size=2000 --model_type=unigram')\n",
        "sp_unigram = spm.SentencePieceProcessor()\n",
        "sp_unigram.load('m_unigram.model')\n",
        "\n",
        "print('*** Unigram ***')\n",
        "print(sp_unigram.encode_as_pieces('thisisatesthelloworld'))\n",
        "print(sp_unigram.nbest_encode_as_pieces('thisisatesthelloworld', 5))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "*** Unigram ***\n",
            "['▁this', 'is', 'ate', 's', 'the', 'llow', 'or', 'l', 'd']\n",
            "[['▁this', 'is', 'ate', 's', 'the', 'llow', 'or', 'l', 'd'], ['▁this', 'i', 's', 'ate', 's', 'the', 'llow', 'or', 'l', 'd'], ['▁this', 'is', 'ate', 'st', 'he', 'llow', 'or', 'l', 'd'], ['▁this', 'is', 'at', 'es', 'the', 'llow', 'or', 'l', 'd'], ['▁this', 'is', 'at', 'est', 'he', 'llow', 'or', 'l', 'd']]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "yJXHCoAHoZWg",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Character and word model\n",
        "\n",
        "Sentencepiece supports character and word segmentation with **--model_type=char** and **--model_type=character** flags.\n",
        "\n",
        "In `word` segmentation, sentencepiece just segments tokens with whitespaces, so the input text must be pre-tokenized.\n",
        "We can apply different segmentation algorithm transparently without changing pre/post processors."
      ]
    },
    {
      "metadata": {
        "id": "pOAOmQGQpBhg",
        "colab_type": "code",
        "outputId": "0bcfa075-4231-4299-e117-c4d02dae0872",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_char --model_type=char --vocab_size=400')\n",
        "\n",
        "sp_char = spm.SentencePieceProcessor()\n",
        "sp_char.load('m_char.model')\n",
        "\n",
        "print(sp_char.encode_as_pieces('this is a test.'))\n",
        "print(sp_char.encode_as_ids('this is a test.'))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁', 't', 'h', 'i', 's', '▁', 'i', 's', '▁', 'a', '▁', 't', 'e', 's', 't', '.']\n",
            "[3, 5, 10, 9, 11, 3, 9, 11, 3, 7, 3, 5, 4, 11, 5, 23]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "uzBiPAm4ljor",
        "colab_type": "code",
        "outputId": "299083c5-e453-4a4b-86eb-c1abd0aa34e3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_word --model_type=word --vocab_size=2000')\n",
        "\n",
        "sp_word = spm.SentencePieceProcessor()\n",
        "sp_word.load('m_word.model')\n",
        "\n",
        "print(sp_word.encode_as_pieces('this is a test.'))  # '.' will not be one token.\n",
        "print(sp_word.encode_as_ids('this is a test.'))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁this', '▁is', '▁a', '▁test.']\n",
            "[31, 17, 8, 0]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "UZvkFnw9pt-D",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Text normalization\n",
        "\n",
        "Sentencepiece provides the following general pre-defined normalization rules. We can change the normalizer with **--normaliation_rule_name=&lt;NAME&gt;** flag.\n",
        "\n",
        "- **nmt_nfkc**: NFKC normalization with some additional normalization around spaces. (default)\n",
        "- **nfkc: original**: NFKC normalization.\n",
        "- **nmt_nfkc_cf**: nmt_nfkc + Unicode case folding (mostly lower casing)\n",
        "- **nfkc_cf**: nfkc + Unicode case folding.\n",
        "- **identity**: no normalization\n",
        "\n"
      ]
    },
    {
      "metadata": {
        "id": "jSJiwIeBqFcO",
        "colab_type": "code",
        "outputId": "c35ec700-dbb9-4601-9d00-6c04a5eef147",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        }
      },
      "cell_type": "code",
      "source": [
        "import sentencepiece as spm\n",
        "\n",
        "# NFKC normalization and lower casing.\n",
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000 --normalization_rule_name=nfkc_cf')\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "print(sp.encode_as_pieces('ＨＥＬＬＯ　ＷＯＲＬＤ.'))  # lower casing and normalization"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁', 'hello', '▁world', '.']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "Fp1QiTjprER4",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The normalization is performed with user-defined string-to-string mappings and leftmost longest matching.\n",
        "We can also define the custom normalization rules as TSV file. The TSV files for pre-defined normalziation rules can be found in the data directory ([sample](https://raw.githubusercontent.com/google/sentencepiece/master/data/nfkc.tsv)). The normalization rule is compiled into FST and embedded in the model file. We don't need to specify the normalization configuration in the segmentation phase.\n",
        "\n",
        "Here's the example of custom normalization. The TSV file is fed with **--normalization_rule_tsv=&lt;FILE&gt;** flag."
      ]
    },
    {
      "metadata": {
        "id": "xHM5aGYTrfXg",
        "colab_type": "code",
        "outputId": "b9652d0d-6e03-486b-fe69-710153e4c906",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 110
        }
      },
      "cell_type": "code",
      "source": [
        "def tocode(s):                                                                               \n",
        "    out = []                                                                                 \n",
        "    for c in s:                                                                              \n",
        "        out.append(str(hex(ord(c))).replace('0x', 'U+'))                                     \n",
        "    return ' '.join(out)          \n",
        "\n",
        "# TSV format:  source Unicode code points <tab> target code points\n",
        "# normalize \"don't => do not,  I'm => I am\"\n",
        "with open('normalization_rule.tsv', 'w') as f:\n",
        "  f.write(tocode(\"I'm\") + '\\t' + tocode(\"I am\") + '\\n')\n",
        "  f.write(tocode(\"don't\") + '\\t' + tocode(\"do not\") + '\\n')\n",
        "\n",
        "print(open('normalization_rule.tsv', 'r').read())\n",
        "\n",
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000 --normalization_rule_tsv=normalization_rule.tsv')\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "# m.model embeds the normalization rule compiled into an FST.\n",
        "sp.load('m.model')\n",
        "print(sp.encode_as_pieces(\"I'm busy\"))  # normalzied to `I am busy'\n",
        "print(sp.encode_as_pieces(\"I don't know it.\"))  # normalized to 'I do not know it.'\n"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "U+49 U+27 U+6d\tU+49 U+20 U+61 U+6d\n",
            "U+64 U+6f U+6e U+27 U+74\tU+64 U+6f U+20 U+6e U+6f U+74\n",
            "\n",
            "['▁I', '▁am', '▁bu', 's', 'y']\n",
            "['▁I', '▁do', '▁not', '▁know', '▁it', '.']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "YdSx1bizvSbH",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Randomizing training data\n",
        "\n",
        "Sentencepiece loads all the lines of training data into memory to train the model.  However, larger training data increases the training time and memory usage, though they are liner to the training data. When **--input_sentence_size=&lt;SIZE&gt;** is specified,  Sentencepiece randomly samples &lt;SIZE&gt; lines from the whole training data.   **--shuffle_input_sentence=false** disables the random shuffle and takes the first &lt;SIZE&gt; lines."
      ]
    },
    {
      "metadata": {
        "id": "FZ089HOXwppS",
        "colab_type": "code",
        "outputId": "607ec114-f443-44e1-ce88-002b5e6deb36",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000 --input_sentence_size=1000')\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "\n",
        "sp.encode_as_pieces('this is a test.')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['▁this', '▁is', '▁a', '▁t', 'est', '.']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 35
        }
      ]
    },
    {
      "metadata": {
        "id": "07FMNoCmglil",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Vocabulary restriction\n",
        "\n",
        "We can encode the text only using the tokens spececified with **set_vocabulary** method.  The background of this feature is described in [subword-nmt page](https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt)."
      ]
    },
    {
      "metadata": {
        "id": "H2soU1eZhdH_",
        "colab_type": "code",
        "outputId": "b1e725d4-a80b-4741-ef1f-1ef0cd3d2665",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 73
        }
      },
      "cell_type": "code",
      "source": [
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000')\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "\n",
        "print(sp.encode_as_pieces('this is a test.'))\n",
        "\n",
        "# Gets all tokens as Python list.\n",
        "vocabs = [sp.id_to_piece(id) for id in range(sp.get_piece_size())]\n",
        "\n",
        "# Aggregates the frequency of each token in the training data.\n",
        "freq = {}\n",
        "with open('botchan.txt', 'r') as f:\n",
        "    for line in f:\n",
        "        line = line.rstrip()\n",
        "        for piece in sp.encode_as_pieces(line):\n",
        "            freq.setdefault(piece, 0)\n",
        "            freq[piece] += 1\n",
        "            \n",
        "# only uses the token appearing more than 1000 times in the training data.\n",
        "vocabs = list(filter(lambda x : x in freq and freq[x] > 1000, vocabs))\n",
        "sp.set_vocabulary(vocabs)\n",
        "print(sp.encode_as_pieces('this is a test.'))\n",
        "\n",
        "# reset the restriction\n",
        "sp.reset_vocabulary()\n",
        "print(sp.encode_as_pieces('this is a test.'))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁this', '▁is', '▁a', '▁t', 'est', '.']\n",
            "['▁', 't', 'h', 'i', 's', '▁', 'i', 's', '▁a', '▁', 't', 'e', 's', 't', '.']\n",
            "['▁this', '▁is', '▁a', '▁t', 'est', '.']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "z8rQLqCTHk40",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Extracting crossing-words pieces\n",
        "\n",
        "Sentencepieces does not extract pieces crossing multiple words (here the `word` means the space delimited tokens). The piece will never contain the whitespace marker (_) in the middle.\n",
        "\n",
        "**--split_by_whtespace=false** disables this restriction and allows to extract pieces crossing multiple words.  In CJK (Chinese/Japanese/Korean), this flag will not affect the final segmentation results so much as  words are not tokenized with whitespaces in CJK."
      ]
    },
    {
      "metadata": {
        "id": "Lf5Fs_pPIKif",
        "colab_type": "code",
        "outputId": "fc1cff92-bb97-4ab9-c52a-8e3c47d1d88f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 147
        }
      },
      "cell_type": "code",
      "source": [
        "import re\n",
        "\n",
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000 --split_by_whitespace=false')\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "\n",
        "# Gets all tokens as Python list.\n",
        "vocabs = [sp.id_to_piece(id) for id in range(sp.get_piece_size())]\n",
        "\n",
        "for piece in vocabs[0:500]:\n",
        "    if re.match('\\w+▁\\w+', piece):\n",
        "        print(piece)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "ed▁to\n",
            "s▁of\n",
            "ing▁the\n",
            "s▁and\n",
            "ed▁by\n",
            "ed▁the\n",
            "ed▁me\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "WWjA7yOX1Rlg",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Training sentencepiece model from the word list with frequency\n",
        "\n",
        "We can train the sentencepiece model from the pair of &lt;word, frequency&gt;. First, you make a TSV file where the first column is the word and the second column is the frequency. Then, feed this TSV file with **--input_format=tsv** flag. Note that when feeding TSV as training data, we implicitly assume that **--split_by_whtespace=true**."
      ]
    },
    {
      "metadata": {
        "id": "T7F349Sd2Bzg",
        "colab_type": "code",
        "outputId": "3453d1f2-2614-4258-ed7d-5a2e230b29ec",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        }
      },
      "cell_type": "code",
      "source": [
        "freq={}\n",
        "with open('botchan.txt', 'r') as f:\n",
        "  for line in f:\n",
        "    line = line.rstrip()\n",
        "    for piece in line.split():\n",
        "      freq.setdefault(piece, 0)\n",
        "      freq[piece] += 1\n",
        "            \n",
        "with open('word_freq_list.tsv', 'w') as f:\n",
        "  for k, v in freq.items():\n",
        "    f.write('%s\\t%d\\n' % (k, v))\n",
        "  \n",
        "\n",
        "import sentencepiece as spm\n",
        "\n",
        "spm.SentencePieceTrainer.train('--input=word_freq_list.tsv --input_format=tsv --model_prefix=m --vocab_size=2000')\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "\n",
        "print(sp.encode_as_pieces('this is a test.'))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['▁this', '▁is', '▁a', '▁t', 'est', '.']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "fiWMoTpA-pHx",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Getting byte offsets of tokens\n",
        "\n",
        "Sentencepiece keeps track of byte offset (span) of each token, which is useful for highlighting the token on top of unnormalized text.\n",
        "\n",
        "We first need to install protobuf module and sentencepiece_pb2.py as the byte offsets and all other meta data for segementation are encoded in protocol buffer.\n",
        "**encode_as_serialized_proto** method resturns serialized SentencePieceText proto. You can get the deserialized object by calling ParseFromString method.\n",
        "\n",
        "The definition of SentencePieceText proto is found [here](https://github.com/google/sentencepiece/blob/3be3f2e11e2bb923c579c6be5e7335809341587f/src/sentencepiece.proto#L23).\n"
      ]
    },
    {
      "metadata": {
        "id": "JTYrvL6KkmVK",
        "colab_type": "code",
        "outputId": "1459a127-7aed-4296-f8e6-0c5e76c6c3d6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 277
        }
      },
      "cell_type": "code",
      "source": [
        "!pip install protobuf\n",
        "!wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_pb2.py"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: protobuf in /usr/local/lib/python3.6/dist-packages (3.7.0)\n",
            "Requirement already satisfied: six>=1.9 in /usr/local/lib/python3.6/dist-packages (from protobuf) (1.11.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf) (40.8.0)\n",
            "--2019-03-27 21:42:35--  https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_pb2.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 7382 (7.2K) [text/plain]\n",
            "Saving to: ‘sentencepiece_pb2.py.1’\n",
            "\n",
            "sentencepiece_pb2.p 100%[===================>]   7.21K  --.-KB/s    in 0s      \n",
            "\n",
            "2019-03-27 21:42:35 (52.3 MB/s) - ‘sentencepiece_pb2.py.1’ saved [7382/7382]\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "KdRy9sEvk7zw",
        "colab_type": "code",
        "outputId": "80b1a4e5-8cbb-46bc-9549-e24444328f79",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 463
        }
      },
      "cell_type": "code",
      "source": [
        "import sentencepiece_pb2\n",
        "import sentencepiece as spm\n",
        "\n",
        "spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000')\n",
        "\n",
        "sp = spm.SentencePieceProcessor()\n",
        "sp.load('m.model')\n",
        "\n",
        "# One best result\n",
        "spt = sentencepiece_pb2.SentencePieceText()\n",
        "spt.ParseFromString(sp.encode_as_serialized_proto('ｈｅｌｌｏ')) # Full width hello\n",
        "\n",
        "# begin/end (offsets) are pointing to the original input.\n",
        "print(spt)\n",
        "\n",
        "# Nbest results\n",
        "nspt = sentencepiece_pb2.NBestSentencePieceText()\n",
        "nspt.ParseFromString(sp.nbest_encode_as_serialized_proto('ｈｅｌｌｏ', 5))\n",
        "# print(nspt)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "text: \"\\357\\275\\210\\357\\275\\205\\357\\275\\214\\357\\275\\214\\357\\275\\217\"\n",
            "pieces {\n",
            "  piece: \"\\342\\226\\201he\"\n",
            "  id: 28\n",
            "  surface: \"\\357\\275\\210\\357\\275\\205\"\n",
            "  begin: 0\n",
            "  end: 6\n",
            "}\n",
            "pieces {\n",
            "  piece: \"ll\"\n",
            "  id: 98\n",
            "  surface: \"\\357\\275\\214\\357\\275\\214\"\n",
            "  begin: 6\n",
            "  end: 12\n",
            "}\n",
            "pieces {\n",
            "  piece: \"o\"\n",
            "  id: 38\n",
            "  surface: \"\\357\\275\\217\"\n",
            "  begin: 12\n",
            "  end: 15\n",
            "}\n",
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "489"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 50
        }
      ]
    }
  ]
}
